import argparse
import torch
from embedding_optimizer import EmbeddingOptimizer
from helper import load_image, save_image_file, inverse_normalize
from models.FinetuneVTmodels import *
from models.MIL_VT import *
from models.FinetuneVTmodels import MIL_VT_FineTune

def _str2bool(value):
    """Convert a command-line token such as ``true``/``false`` into a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main(argv=None):
    """Optimize one image so that its embedding approaches a target image's.

    Parses the command line, loads the fine-tuned model from a checkpoint,
    computes the target image's embeddings, runs the PGD-based embedding
    optimisation on the current image and saves the de-normalised result.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.
    """
    parser = argparse.ArgumentParser(description="Optimize Single Image Embedding")
    parser.add_argument("--current_image_path", type=str, required=True, help="File path of the current image")
    parser.add_argument("--target_image_path", type=str, required=True, help="File path of the target image")
    parser.add_argument("--learning_rate", type=float, default=0.08, help="Learning rate for gradient descent")
    parser.add_argument("--l2_dist_threshold", type=float, default=1e-4, help="Squared L2 distance threshold")
    parser.add_argument("--cosine_sim_threshold", type=float, default=0.97, help="Cosine similarity threshold")
    # Accepts both the bare flag (`--mil_emb`) and an explicit value
    # (`--mil_emb true` / `--mil_emb false`).
    parser.add_argument(
        "--mil_emb",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Choice of embedding type: true targets the MIL embedding, false the ViT embedding",
    )
    parser.add_argument("--output_path", type=str, required=True, help="File path to save the optimized image")
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default='path/to/saved_weight/file/',
        help="File path of the saved model weights",
    )
    parser.add_argument("--epsilon", type=float, default=0.02, help="Epsilon value used for PGD")
    args = parser.parse_args(argv)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = MIL_VT_FineTune()
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)

    # NOTE(review): `transforms` is not imported explicitly in this file; it is
    # presumably re-exported by one of the `models.*` star imports - confirm.
    preprocess = transforms.Compose([
        transforms.Resize((384, 384)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    optimizer = EmbeddingOptimizer(model, args.learning_rate, args.epsilon)

    current_image = load_image(args.current_image_path, preprocess, device)
    target_image = load_image(args.target_image_path, preprocess, device)

    # The target embeddings are fixed reference values, so no autograd graph is
    # needed for them.
    with torch.no_grad():
        target_vit_embedding, target_mil_embedding = model.forward_features(target_image)

    # The optimisation target must match the embedding type selected on the CLI.
    target_embedding = target_mil_embedding if args.mil_emb else target_vit_embedding

    optimized_image, _, _, _ = optimizer.optimize_embeddings_pgd(
        current_image,
        target_embedding,
        args.l2_dist_threshold,
        args.cosine_sim_threshold,
        args.mil_emb,
    )
    optimized_image_inv = inverse_normalize()(optimized_image)

    save_image_file(optimized_image_inv, args.output_path, "optimized_image")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

